Diffusion Mapping¶
In [1]:
import numpy as np

# Forward "diffusion" random walk: each step shrinks the previous value by
# sqrt(1 - b) and adds Gaussian noise of variance b, so every chain drifts
# from its start toward the stationary N(0, 1) distribution.
x = [-10, 0, 10]   # starting means of the three chains
b = .03            # per-step noise variance
n_steps = 1000

output_matrix = np.zeros((n_steps, 3))
# Initialize each chain around its starting mean. Vectorized over the three
# columns (the original looped with an unused `value` loop variable and
# re-indexed x[i]).
output_matrix[0, :] = np.random.normal(x, np.sqrt(b))
for step in range(1, n_steps):
    prev = output_matrix[step - 1, :]
    # x_t = sqrt(1 - b) * x_{t-1} + N(0, b), applied to all 3 chains at once
    # (replaces three duplicated per-column update lines)
    output_matrix[step, :] = np.random.normal(np.sqrt(1 - b) * prev, np.sqrt(b))
import matplotlib.pyplot as plt

# Trace each of the three chains over the full 1000-step horizon.
steps = range(1000)
for col in range(3):
    plt.plot(steps, output_matrix[:, col], label=f'Value {col+1}')

plt.xlabel('Time')
plt.ylabel('Value')
plt.legend()
plt.show()
In [2]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Closed-form marginal of the walk after `tp` extra steps, conditioned on the
# observed state z at that step: N(sqrt((1-b)^tp) * z, 1 - (1-b)^tp).
tps = [1, 10, 25, 50, 100, 200]
bt = np.power(1 - b, tps)

fig, axs = plt.subplots(2, 3, figsize=(12, 8))
for i, tp in enumerate(tps):
    z = output_matrix[tp - 1, :]
    mean = np.sqrt(bt[i]) * z
    variance = 1 - bt[i]
    ax = axs[i // 3, i % 3]
    # one KDE per chain, drawn from the analytic marginal
    # (iterating over mean keeps the same RNG draw order as before)
    for m in mean:
        draws = np.random.normal(m, np.sqrt(variance), 10000)
        sns.kdeplot(draws, ax=ax)
    ax.set_title(f'Step = {tp}')

plt.tight_layout()
plt.show()
Simple Diffusion for Sampling from 9 Component Mixture¶
In [1]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import matplotlib.pyplot as plt

# -----------------------------------------------------------------------------
# 1) 2D dataset: mixture of nine Gaussians centered on a 3x3 grid
# -----------------------------------------------------------------------------
grid = (-2, 0, 2)
centers = torch.tensor([[cx, cy] for cx in grid for cy in grid],
                       dtype=torch.float32)

n_samples = 10000
# pick a mixture component uniformly for every sample
idx = torch.randint(0, 9, (n_samples,))
# perturb the chosen center with isotropic Gaussian noise, std 0.1
x0 = centers[idx] + torch.randn(n_samples, 2) * 0.1

dataset = TensorDataset(x0)
loader = DataLoader(dataset, batch_size=256, shuffle=True)
In [2]:
import seaborn as sns
import matplotlib.pyplot as plt

# 2-D kernel density of the raw mixture samples (should show 9 modes)
xs = x0[:, 0].numpy()
ys = x0[:, 1].numpy()

plt.figure(figsize=(6,6))
sns.kdeplot(x=xs, y=ys, fill=True, cmap="viridis")
plt.title("Density Plot for x0")
plt.xlabel("x coordinate")
plt.ylabel("y coordinate")
plt.show()
In [9]:
# -----------------------------------------------------------------------------
# 2) Diffusion schedule & forward q_sample
# -----------------------------------------------------------------------------
T = 500
beta = torch.linspace(1e-4, .05, T)      # linear noise schedule beta_t
alpha = 1 - beta                         # alpha_t = 1 - beta_t
alphac = torch.cumprod(alpha, dim=0)     # alpha_bar_t = prod_{s<=t} alpha_s

def q_sample(x0, t, noise):
    """Diffuse clean samples x0 to step t in closed form.

        x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise

    Args:
        x0:    [B, 2] clean samples
        t:     [B] integer timesteps indexing into alphac
        noise: [B, 2] standard normal noise
    Returns:
        [B, 2] diffused samples.
    """
    # alpha_bar per sample, shaped [B, 1] so it broadcasts over the 2 coords
    a_bar = alphac[t].unsqueeze(-1)
    return a_bar.sqrt() * x0 + (1.0 - a_bar).sqrt() * noise
In [10]:
import matplotlib.pyplot as plt

# Visualize the deterministic part of the forward process: with the noise
# fixed at zero, q_sample reduces to sqrt(alpha_bar_t) * x0, so the nine
# clusters collapse toward the origin as t grows.
t_values = [0, 100, 200, 300, 400]
# zero noise tensor, same shape as x0
# (the previous comment claimed a value of 0.01, but the code always used 0)
noise_const = torch.zeros_like(x0)

fig, axes = plt.subplots(1, 5, figsize=(20, 4))
for ax, t_val in zip(axes, t_values):
    # diffuse all samples to time t_val using the closed-form kernel
    xt = q_sample(x0, torch.tensor(t_val), noise_const)
    xt_np = xt.detach().numpy()
    ax.scatter(xt_np[:, 0], xt_np[:, 1], s=1, alpha=0.6)
    ax.set_title(f"t = {t_val}")
    ax.set_xlim([-3, 3])
    ax.set_ylim([-3, 3])
    ax.set_aspect('equal')
plt.tight_layout()
plt.show()
In [11]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
# -------------------------------------------------------------------
# 1) Sinusoidal positional embedding for timesteps
# -------------------------------------------------------------------
class SinusoidalPosEmb(nn.Module):
    """Transformer-style sinusoidal embedding of (possibly fractional) timesteps."""

    def __init__(self, dim):
        """dim: width of the output embedding (must be even: half sin, half cos)."""
        super().__init__()
        self.dim = dim

    def forward(self, t):
        """Map t of shape [B] to a [B, dim] embedding of sin/cos features."""
        half = self.dim // 2
        # geometric frequency ladder from 1 down to 1/10000
        scale = -math.log(10000) / (half - 1)
        freqs = torch.exp(scale * torch.arange(half, device=t.device).float())
        angles = t.float().unsqueeze(1) * freqs.unsqueeze(0)    # [B, half]
        return torch.cat([angles.sin(), angles.cos()], dim=-1)  # [B, dim]
# -------------------------------------------------------------------
# 2) Score network with explicit time embedding
# -------------------------------------------------------------------
class ScoreNet2D(nn.Module):
    """Noise-prediction MLP for 2-D points, conditioned on the timestep."""

    def __init__(self, temb_dim=64, hidden_dim=128):
        """
        temb_dim:   dimension of the sinusoidal time embedding
        hidden_dim: hidden layer width of the main MLP
        """
        super().__init__()
        # t -> sinusoidal features -> linear projection + SiLU
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(temb_dim),
            nn.Linear(temb_dim, temb_dim),
            nn.SiLU(),
        )
        # predicts eps from the concatenation of the point and its time embedding
        self.net = nn.Sequential(
            nn.Linear(2 + temb_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, 2)
        )

    def forward(self, x, t):
        """
        x: [B, 2] points; t: [B] integer timesteps in {0,...,T-1}.
        Returns the predicted noise, shape [B, 2].
        """
        # normalize the raw step index to [0, 1] (T is the module-level horizon)
        t_scaled = t.float() / (T - 1)
        temb = self.time_mlp(t_scaled)                 # [B, temb_dim]
        return self.net(torch.cat([x, temb], dim=1))   # [B, 2]
In [12]:
# -----------------------------------------------------------------------------
# 4) Reverse (ancestral) sampler
# -----------------------------------------------------------------------------
@torch.no_grad()
def sample(model, n_samples, device):
    """Ancestral (DDPM) sampler: start from N(0, I) and denoise for T steps.

    Relies on the module-level schedule tensors beta/alpha/alphac and the
    horizon T. Returns an [n_samples, 2] tensor on the CPU.
    """
    model.eval()
    x = torch.randn(n_samples, 2, device=device)
    for step in reversed(range(T)):
        t_batch = torch.full((n_samples,), step, device=device, dtype=torch.long)
        # predicted noise, clamped to keep early (high-noise) steps stable
        eps = model(x, t_batch).clamp(-5, 5)
        beta_t = beta[step]
        alpha_t = alpha[step]
        abar_t = alphac[step]
        # posterior mean: (1/sqrt(alpha_t)) * (x - beta_t/sqrt(1-abar_t) * eps)
        mu = (1 / math.sqrt(alpha_t)) * (x - (beta_t / math.sqrt(1 - abar_t)) * eps)
        if step > 0:
            # add fresh noise with variance beta_t on all but the final step
            x = mu + math.sqrt(beta_t) * torch.randn_like(x)
        else:
            x = mu
    return x.cpu()
In [18]:
# Move the schedule tensors to the training device so indexing inside
# sample()/q_sample stays on-device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
beta = beta.to(device)
alpha = alpha.to(device)
alphac = alphac.to(device)
# Model and optimizer for the 2-D toy problem
model = ScoreNet2D(temb_dim=64, hidden_dim=128).to(device)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
In [19]:
# Training loop: standard DDPM noise-prediction (eps-matching) objective.
epochs = 500
for epoch in range(epochs):
    running_loss = 0.0
    for (x0_batch,) in loader:
        x0_batch = x0_batch.to(device)
        b = x0_batch.size(0)
        # sample a random timestep per example and diffuse x0 to x_t
        t = torch.randint(0, T, (b,), device=device)
        noise = torch.randn_like(x0_batch)
        xt = q_sample(x0_batch, t, noise)
        # predict the injected noise and regress onto it
        eps_pred = model(xt, t)
        loss = F.mse_loss(eps_pred, noise)
        opt.zero_grad()
        loss.backward()
        # clip gradients to guard against explosion
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()
        running_loss += loss.item() * b
    # every 25 epochs (and at epoch 1): report loss and plot generated samples
    if (epoch+1) % 25 == 0 or epoch == 0:
        avg_loss = running_loss / n_samples
        print(f"Epoch {epoch+1:3d}/{epochs}, Loss: {avg_loss:.4f}")
        gen = sample(model, 10000, device)  # [10000, 2]
        gen = gen.numpy()
        plt.figure(figsize=(6,6))
        plt.hist2d(gen[:,0], gen[:,1], bins=100)
        plt.title(f"Density of Generated Samples at Epoch {epoch+1}")
        plt.axis('equal')
        plt.show()
Epoch 1/500, Loss: 0.6537
Epoch 25/500, Loss: 0.2749
Epoch 50/500, Loss: 0.2534
Epoch 75/500, Loss: 0.2493
Epoch 100/500, Loss: 0.2259
Epoch 125/500, Loss: 0.2367
Epoch 150/500, Loss: 0.2162
Epoch 175/500, Loss: 0.2207
Epoch 200/500, Loss: 0.2114
Epoch 225/500, Loss: 0.2090
Epoch 250/500, Loss: 0.2205
Epoch 275/500, Loss: 0.2301
Epoch 300/500, Loss: 0.2164
Epoch 325/500, Loss: 0.2218
Epoch 350/500, Loss: 0.2287
Epoch 375/500, Loss: 0.2175
Epoch 400/500, Loss: 0.2117
Epoch 425/500, Loss: 0.2138
Epoch 450/500, Loss: 0.2094
Epoch 475/500, Loss: 0.2280
Epoch 500/500, Loss: 0.2273
Diffusion Model w/ Attention for CelebA¶
In [1]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
import torchvision.utils as vutils
from tqdm import tqdm
# -----------------------------------------------------------------------------
# 1) Noise schedule + sampling helper
# -----------------------------------------------------------------------------
def linear_beta_schedule(timesteps, beta_start=1e-4, beta_end=0.02):
    """Evenly spaced DDPM variance schedule beta_1..beta_T of length `timesteps`."""
    return torch.linspace(beta_start, beta_end, timesteps)
class DiffusionSchedule:
    """Precomputed DDPM noise schedule plus the closed-form forward kernel.

    Args:
        timesteps: number of diffusion steps T.
        device:    where to keep the schedule tensors. Defaults to CUDA when
                   available (previously this silently read a global `device`
                   that was only defined in a later cell).
    """
    def __init__(self, timesteps=1000, device=None):
        if device is None:
            # same choice the notebook makes globally, computed locally so the
            # class no longer depends on definition order across cells
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.timesteps = timesteps
        self.betas = linear_beta_schedule(timesteps).to(device)
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        # cached square roots used by q_sample
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_acumprod = torch.sqrt(1 - self.alphas_cumprod)

    def q_sample(self, x_start, t, noise):
        """
        Diffuse x_start to x_t by adding noise at step t:
            x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * eps
        x_start/noise: [B, C, H, W]; t: [B] long tensor of step indices.
        """
        return (
            self.sqrt_alphas_cumprod[t].view(-1,1,1,1) * x_start +
            self.sqrt_one_minus_acumprod[t].view(-1,1,1,1) * noise
        )
In [2]:
@torch.no_grad()
def p_sample_loop(model, diffusion, shape):
    """Ancestral DDPM sampler: draw `shape`-shaped images from pure noise.

    Args:
        model:     noise-prediction network taking (x_t, t/T) with t scaled
                   to [0, 1).
        diffusion: DiffusionSchedule providing timesteps/betas/alphas/
                   alphas_cumprod.
        shape:     output shape (B, C, H, W).
    Returns:
        Generated images clamped to [-1, 1], on the model's device.

    The device is taken from the model's parameters (previously a hidden
    global), and the model's train/eval mode is restored on exit instead of
    being forced back to train().
    """
    was_training = model.training
    model.eval()
    device = next(model.parameters()).device
    x = torch.randn(shape, device=device)
    for i in reversed(range(diffusion.timesteps)):
        # 1) prepare timestep tensor and predict eps (clamped for stability)
        t = torch.full((shape[0],), i, device=device, dtype=torch.long)
        eps_pred = model(x, t.float() / diffusion.timesteps)
        eps_pred = eps_pred.clamp(-5.0, 5.0)
        # 2) scalars for this step
        beta_t = diffusion.betas[i]
        alpha_t = diffusion.alphas[i]
        alpha_bar_t = diffusion.alphas_cumprod[i]
        # 3) numerically-guarded square roots
        sqrt_alpha_t = torch.sqrt(alpha_t).clamp(min=1e-5)
        sqrt_one_minus_ab = torch.sqrt(1 - alpha_bar_t).clamp(min=1e-5)
        # 4) posterior mean:
        #    mu = (1/sqrt(alpha_t)) * (x_t - (beta_t / sqrt(1-abar_t)) * eps)
        mean = (1.0 / sqrt_alpha_t) * (
            x - (beta_t / sqrt_one_minus_ab) * eps_pred
        )
        # 5) sample from p(x_{t-1} | x_t); sigma_t = sqrt(beta_t)
        if i > 0:
            noise = torch.randn_like(x)
            x = mean + torch.sqrt(beta_t) * noise
        else:
            x = mean
    if was_training:
        model.train()
    return x.clamp(-1, 1)
In [3]:
# -----------------------------------------------------------------------------
# 2) Time embedding, ResBlock, MHSA2d, Down/Up, UNetAttention
# -----------------------------------------------------------------------------
class SinusoidalPosEmb(nn.Module):
    """Sinusoidal timestep embedding: t of shape [B] -> [B, dim] sin/cos features."""

    def __init__(self, dim):
        super().__init__()
        self.dim = dim  # must be even: half sin, half cos

    def forward(self, t):
        half = self.dim // 2
        # frequencies decay geometrically from 1 down to 1/10000
        scale = -math.log(10000) / (half - 1)
        freqs = torch.exp(torch.arange(half, device=t.device).float() * scale)
        angles = t.float()[:, None] * freqs[None]
        return torch.cat([angles.sin(), angles.cos()], dim=-1)
class ResBlock(nn.Module):
    """GroupNorm -> SiLU -> Conv residual block with optional timestep injection."""

    def __init__(self, in_ch, out_ch, temb_dim=None, dropout=0.0):
        super().__init__()
        # submodules registered in the original order (keeps init RNG identical)
        self.norm1 = nn.GroupNorm(8, in_ch)
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.norm2 = nn.GroupNorm(8, out_ch)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        # projects the time embedding into the channel dimension, if provided
        self.emb_proj = nn.Linear(temb_dim, out_ch) if temb_dim is not None else None
        # 1x1 conv aligns channels on the skip path when they differ
        self.res_conv = nn.Conv2d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity()

    def forward(self, x, temb):
        out = self.norm1(x)
        out = self.conv1(F.silu(out))
        if self.emb_proj is not None:
            # broadcast the [B, out_ch] embedding over the spatial dims
            out = out + self.emb_proj(temb)[:, :, None, None]
        out = F.silu(self.norm2(out))
        out = self.conv2(self.dropout(out))
        return out + self.res_conv(x)
In [4]:
class MHSA2d(nn.Module):
    """Multi-head self-attention over the spatial positions of a feature map."""

    def __init__(self, channels, num_heads=4, dropout=0.0):
        super().__init__()
        self.norm = nn.GroupNorm(8, channels)
        # sequence-first layout: attention inputs are (S, B, C) with S = H*W
        self.attn = nn.MultiheadAttention(
            embed_dim=channels,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=False
        )
        self.proj = nn.Conv2d(channels, channels, 1)

    def forward(self, x):
        B, C, H, W = x.shape
        # flatten space into a sequence: (B,C,H,W) -> (H*W, B, C)
        seq = self.norm(x).view(B, C, H*W).permute(2, 0, 1)
        attended, _ = self.attn(seq, seq, seq)
        attended = attended.permute(1, 2, 0).view(B, C, H, W)
        # residual connection through a learned 1x1 output projection
        return x + self.proj(attended)
In [5]:
class Downsample(nn.Module):
    """Halve spatial resolution with a stride-2 3x3 convolution."""

    def __init__(self, ch):
        super().__init__()
        self.conv = nn.Conv2d(ch, ch, 3, 2, 1)

    def forward(self, x):
        return self.conv(x)
class Upsample(nn.Module):
    """Double spatial resolution: nearest-neighbor upsample then a 3x3 conv."""

    def __init__(self, ch):
        super().__init__()
        self.conv = nn.Conv2d(ch, ch, 3, 1, 1)

    def forward(self, x):
        upsampled = F.interpolate(x, scale_factor=2, mode="nearest")
        return self.conv(upsampled)
In [6]:
class UNetAttention(nn.Module):
    """UNet noise-predictor with timestep embedding and bottleneck self-attention.

    The down path runs `num_res_blocks` ResBlocks per channel multiplier and
    downsamples between stages; every ResBlock output is saved as a skip.
    The up path consumes those skips in reverse, concatenating each onto the
    feature map before its ResBlock. Input and output are [B, in_ch, H, W].
    NOTE(review): H and W must be divisible by 2**(len(chan_mults)-1) for the
    skip shapes to line up — confirm against caller (64x64 here).
    """
    def __init__(
        self,
        in_ch=3,
        base_ch=128,
        chan_mults=(1,1,2,2,4,4),
        num_res_blocks=2,
        temb_dim=512,
        dropout=0.0
    ):
        """
        in_ch:          image channels (3 for RGB)
        base_ch:        channel width at full resolution
        chan_mults:     per-stage channel multipliers; one stage per entry
        num_res_blocks: ResBlocks per stage (down and up)
        temb_dim:       width of the timestep embedding
        dropout:        dropout rate inside each ResBlock
        """
        super().__init__()
        self.num_res_blocks = num_res_blocks
        # 1) Time-step embedding MLP: sinusoidal features -> 2-layer projection
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(temb_dim//2),
            nn.Linear(temb_dim//2, temb_dim),
            nn.SiLU(),
            nn.Linear(temb_dim, temb_dim),
        )
        # 2) Initial conv lifts the image into base_ch channels
        ch = base_ch
        self.init_conv = nn.Conv2d(in_ch, ch, 3, padding=1)
        # 3) Down path: build blocks & record skip channels
        self.down_blocks = nn.ModuleList()
        self.down_samples = nn.ModuleList()
        skip_channels = []
        for i, mult in enumerate(chan_mults):
            out_ch = base_ch * mult
            for _ in range(num_res_blocks):
                self.down_blocks.append( ResBlock(ch, out_ch, temb_dim, dropout) )
                skip_channels.append(out_ch) # record this for the up path
                ch = out_ch
            # no downsample after the last (lowest-resolution) stage
            if i < len(chan_mults) - 1:
                self.down_samples.append( Downsample(ch) )
        # 4) Bottleneck: ResBlock -> self-attention -> ResBlock
        self.mid1 = ResBlock(ch, ch, temb_dim, dropout)
        self.mid_attn = MHSA2d(ch, num_heads=4, dropout=dropout)
        self.mid2 = ResBlock(ch, ch, temb_dim, dropout)
        # 5) Up path: consume skip_channels in reverse
        self.up_blocks = nn.ModuleList()
        self.up_samples = nn.ModuleList()
        skip_idx = len(skip_channels) - 1
        for i, mult in reversed(list(enumerate(chan_mults))):
            out_ch = base_ch * mult
            # add an upsample layer if we're not at the very first (lowest) resolution
            if i < len(chan_mults) - 1:
                self.up_samples.append( Upsample(ch) )
            # for each residual block in this stage, cat with a skip from down
            for _ in range(num_res_blocks):
                skip_ch = skip_channels[skip_idx]
                skip_idx -= 1
                # input is current features + concatenated skip channels
                self.up_blocks.append( ResBlock(ch + skip_ch, out_ch, temb_dim, dropout) )
                ch = out_ch
        # 6) Final norm -> activation -> 1x1 conv back to image channels
        self.final_norm = nn.GroupNorm(8, ch)
        self.final_conv = nn.Conv2d(ch, in_ch, 1)

    def forward(self, x, t):
        """x: [B, in_ch, H, W]; t: [B] (already scaled) timesteps. Returns eps prediction."""
        # time embed
        temb = self.time_mlp(t)
        # down path: run every ResBlock, saving each output as a skip
        h = self.init_conv(x)
        skips = []
        bi = 0
        for block in self.down_blocks:
            h = block(h, temb)
            skips.append(h)
            # after each group of num_res_blocks, do a downsample if available
            if (bi + 1) % self.num_res_blocks == 0 and (bi // self.num_res_blocks) < len(self.down_samples):
                h = self.down_samples[bi // self.num_res_blocks](h)
            bi += 1
        # bottleneck
        h = self.mid1(h, temb)
        h = self.mid_attn(h)
        h = self.mid2(h, temb)
        # up path: pop skips in reverse and concatenate before each block
        bi = 0
        for block in self.up_blocks:
            # only upsample at the start of each new resolution,
            # and never at bi=0 (i.e. before the first two blocks).
            if bi % self.num_res_blocks == 0 and bi > 0:
                us_idx = bi // self.num_res_blocks - 1
                h = self.up_samples[us_idx](h)
            skip = skips.pop()
            h = block(torch.cat([h, skip], dim=1), temb)
            bi += 1
        # final norm -> SiLU -> conv
        h = F.silu(self.final_norm(h))
        return self.final_conv(h)
In [7]:
# -----------------------------------------------------------------------------
# 3) Data, model, optimizer, schedule (CelebA subset @ 128×128)
# -----------------------------------------------------------------------------
from pathlib import Path
from PIL import Image
from torch.utils.data import Dataset
class RawImageDataset(Dataset):
    """Flat-directory image dataset: every file with a matching extension is a sample."""

    def __init__(self, root, transform=None, exts=(".png",".jpg",".jpeg")):
        # NOTE(review): non-recursive — only direct children of `root` are scanned
        self.paths = [p for p in Path(root).iterdir() if p.suffix.lower() in exts]
        self.transform = transform

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        # always decode to RGB so grayscale files still yield 3 channels
        image = Image.open(self.paths[idx]).convert("RGB")
        return self.transform(image) if self.transform else image
# Training device and preprocessing: resize to 64x64, then scale to [-1, 1]
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
transform = transforms.Compose([
    transforms.Resize((64,64)),
    transforms.ToTensor(),
    # (x - 0.5) / 0.5 per channel: maps [0, 1] -> [-1, 1]
    transforms.Normalize((0.5,)*3, (0.5,)*3),
])
# NOTE(review): relative path — assumes the notebook runs next to `subset_images/`
train_ds = RawImageDataset("subset_images", transform=transform)
train_loader = DataLoader(
    train_ds,
    batch_size=32,
    shuffle=True,
    num_workers=4,
    pin_memory=True
)
# model & diffusion
model = UNetAttention().to(device)
diffusion = DiffusionSchedule(timesteps=500)
optimizer = optim.Adam(model.parameters(), lr=2e-4)
In [8]:
from torchinfo import summary
# Dummy inputs for the summary: the model expects an image x of shape
# (B, 3, 64, 64) and a timestep tensor t of shape (B,)
dummy_x = torch.randn(1, 3, 64, 64, device=device)
dummy_t = torch.tensor([0.0], device=device)  # dummy timestep, already scaled to [0, 1)
# Print the layer-by-layer model summary
summary(model, input_data=(dummy_x, dummy_t))
Out[8]:
========================================================================================== Layer (type:depth-idx) Output Shape Param # ========================================================================================== UNetAttention [1, 3, 64, 64] -- ├─Sequential: 1-1 [1, 512] -- │ └─SinusoidalPosEmb: 2-1 [1, 256] -- │ └─Linear: 2-2 [1, 512] 131,584 │ └─SiLU: 2-3 [1, 512] -- │ └─Linear: 2-4 [1, 512] 262,656 ├─Conv2d: 1-2 [1, 128, 64, 64] 3,584 ├─ModuleList: 1-13 -- (recursive) │ └─ResBlock: 2-5 [1, 128, 64, 64] -- │ │ └─GroupNorm: 3-1 [1, 128, 64, 64] 256 │ │ └─Conv2d: 3-2 [1, 128, 64, 64] 147,584 │ │ └─Linear: 3-3 [1, 128] 65,664 │ │ └─GroupNorm: 3-4 [1, 128, 64, 64] 256 │ │ └─Dropout: 3-5 [1, 128, 64, 64] -- │ │ └─Conv2d: 3-6 [1, 128, 64, 64] 147,584 │ │ └─Identity: 3-7 [1, 128, 64, 64] -- │ └─ResBlock: 2-6 [1, 128, 64, 64] -- │ │ └─GroupNorm: 3-8 [1, 128, 64, 64] 256 │ │ └─Conv2d: 3-9 [1, 128, 64, 64] 147,584 │ │ └─Linear: 3-10 [1, 128] 65,664 │ │ └─GroupNorm: 3-11 [1, 128, 64, 64] 256 │ │ └─Dropout: 3-12 [1, 128, 64, 64] -- │ │ └─Conv2d: 3-13 [1, 128, 64, 64] 147,584 │ │ └─Identity: 3-14 [1, 128, 64, 64] -- ├─ModuleList: 1-12 -- (recursive) │ └─Downsample: 2-7 [1, 128, 32, 32] -- │ │ └─Conv2d: 3-15 [1, 128, 32, 32] 147,584 ├─ModuleList: 1-13 -- (recursive) │ └─ResBlock: 2-8 [1, 128, 32, 32] -- │ │ └─GroupNorm: 3-16 [1, 128, 32, 32] 256 │ │ └─Conv2d: 3-17 [1, 128, 32, 32] 147,584 │ │ └─Linear: 3-18 [1, 128] 65,664 │ │ └─GroupNorm: 3-19 [1, 128, 32, 32] 256 │ │ └─Dropout: 3-20 [1, 128, 32, 32] -- │ │ └─Conv2d: 3-21 [1, 128, 32, 32] 147,584 │ │ └─Identity: 3-22 [1, 128, 32, 32] -- │ └─ResBlock: 2-9 [1, 128, 32, 32] -- │ │ └─GroupNorm: 3-23 [1, 128, 32, 32] 256 │ │ └─Conv2d: 3-24 [1, 128, 32, 32] 147,584 │ │ └─Linear: 3-25 [1, 128] 65,664 │ │ └─GroupNorm: 3-26 [1, 128, 32, 32] 256 │ │ └─Dropout: 3-27 [1, 128, 32, 32] -- │ │ └─Conv2d: 3-28 [1, 128, 32, 32] 147,584 │ │ └─Identity: 3-29 [1, 128, 32, 32] -- ├─ModuleList: 1-12 -- (recursive) │ └─Downsample: 2-10 
[1, 128, 16, 16] -- │ │ └─Conv2d: 3-30 [1, 128, 16, 16] 147,584 ├─ModuleList: 1-13 -- (recursive) │ └─ResBlock: 2-11 [1, 256, 16, 16] -- │ │ └─GroupNorm: 3-31 [1, 128, 16, 16] 256 │ │ └─Conv2d: 3-32 [1, 256, 16, 16] 295,168 │ │ └─Linear: 3-33 [1, 256] 131,328 │ │ └─GroupNorm: 3-34 [1, 256, 16, 16] 512 │ │ └─Dropout: 3-35 [1, 256, 16, 16] -- │ │ └─Conv2d: 3-36 [1, 256, 16, 16] 590,080 │ │ └─Conv2d: 3-37 [1, 256, 16, 16] 33,024 │ └─ResBlock: 2-12 [1, 256, 16, 16] -- │ │ └─GroupNorm: 3-38 [1, 256, 16, 16] 512 │ │ └─Conv2d: 3-39 [1, 256, 16, 16] 590,080 │ │ └─Linear: 3-40 [1, 256] 131,328 │ │ └─GroupNorm: 3-41 [1, 256, 16, 16] 512 │ │ └─Dropout: 3-42 [1, 256, 16, 16] -- │ │ └─Conv2d: 3-43 [1, 256, 16, 16] 590,080 │ │ └─Identity: 3-44 [1, 256, 16, 16] -- ├─ModuleList: 1-12 -- (recursive) │ └─Downsample: 2-13 [1, 256, 8, 8] -- │ │ └─Conv2d: 3-45 [1, 256, 8, 8] 590,080 ├─ModuleList: 1-13 -- (recursive) │ └─ResBlock: 2-14 [1, 256, 8, 8] -- │ │ └─GroupNorm: 3-46 [1, 256, 8, 8] 512 │ │ └─Conv2d: 3-47 [1, 256, 8, 8] 590,080 │ │ └─Linear: 3-48 [1, 256] 131,328 │ │ └─GroupNorm: 3-49 [1, 256, 8, 8] 512 │ │ └─Dropout: 3-50 [1, 256, 8, 8] -- │ │ └─Conv2d: 3-51 [1, 256, 8, 8] 590,080 │ │ └─Identity: 3-52 [1, 256, 8, 8] -- │ └─ResBlock: 2-15 [1, 256, 8, 8] -- │ │ └─GroupNorm: 3-53 [1, 256, 8, 8] 512 │ │ └─Conv2d: 3-54 [1, 256, 8, 8] 590,080 │ │ └─Linear: 3-55 [1, 256] 131,328 │ │ └─GroupNorm: 3-56 [1, 256, 8, 8] 512 │ │ └─Dropout: 3-57 [1, 256, 8, 8] -- │ │ └─Conv2d: 3-58 [1, 256, 8, 8] 590,080 │ │ └─Identity: 3-59 [1, 256, 8, 8] -- ├─ModuleList: 1-12 -- (recursive) │ └─Downsample: 2-16 [1, 256, 4, 4] -- │ │ └─Conv2d: 3-60 [1, 256, 4, 4] 590,080 ├─ModuleList: 1-13 -- (recursive) │ └─ResBlock: 2-17 [1, 512, 4, 4] -- │ │ └─GroupNorm: 3-61 [1, 256, 4, 4] 512 │ │ └─Conv2d: 3-62 [1, 512, 4, 4] 1,180,160 │ │ └─Linear: 3-63 [1, 512] 262,656 │ │ └─GroupNorm: 3-64 [1, 512, 4, 4] 1,024 │ │ └─Dropout: 3-65 [1, 512, 4, 4] -- │ │ └─Conv2d: 3-66 [1, 512, 4, 4] 2,359,808 │ │ └─Conv2d: 3-67 [1, 
512, 4, 4] 131,584 │ └─ResBlock: 2-18 [1, 512, 4, 4] -- │ │ └─GroupNorm: 3-68 [1, 512, 4, 4] 1,024 │ │ └─Conv2d: 3-69 [1, 512, 4, 4] 2,359,808 │ │ └─Linear: 3-70 [1, 512] 262,656 │ │ └─GroupNorm: 3-71 [1, 512, 4, 4] 1,024 │ │ └─Dropout: 3-72 [1, 512, 4, 4] -- │ │ └─Conv2d: 3-73 [1, 512, 4, 4] 2,359,808 │ │ └─Identity: 3-74 [1, 512, 4, 4] -- ├─ModuleList: 1-12 -- (recursive) │ └─Downsample: 2-19 [1, 512, 2, 2] -- │ │ └─Conv2d: 3-75 [1, 512, 2, 2] 2,359,808 ├─ModuleList: 1-13 -- (recursive) │ └─ResBlock: 2-20 [1, 512, 2, 2] -- │ │ └─GroupNorm: 3-76 [1, 512, 2, 2] 1,024 │ │ └─Conv2d: 3-77 [1, 512, 2, 2] 2,359,808 │ │ └─Linear: 3-78 [1, 512] 262,656 │ │ └─GroupNorm: 3-79 [1, 512, 2, 2] 1,024 │ │ └─Dropout: 3-80 [1, 512, 2, 2] -- │ │ └─Conv2d: 3-81 [1, 512, 2, 2] 2,359,808 │ │ └─Identity: 3-82 [1, 512, 2, 2] -- │ └─ResBlock: 2-21 [1, 512, 2, 2] -- │ │ └─GroupNorm: 3-83 [1, 512, 2, 2] 1,024 │ │ └─Conv2d: 3-84 [1, 512, 2, 2] 2,359,808 │ │ └─Linear: 3-85 [1, 512] 262,656 │ │ └─GroupNorm: 3-86 [1, 512, 2, 2] 1,024 │ │ └─Dropout: 3-87 [1, 512, 2, 2] -- │ │ └─Conv2d: 3-88 [1, 512, 2, 2] 2,359,808 │ │ └─Identity: 3-89 [1, 512, 2, 2] -- ├─ResBlock: 1-14 [1, 512, 2, 2] -- │ └─GroupNorm: 2-22 [1, 512, 2, 2] 1,024 │ └─Conv2d: 2-23 [1, 512, 2, 2] 2,359,808 │ └─Linear: 2-24 [1, 512] 262,656 │ └─GroupNorm: 2-25 [1, 512, 2, 2] 1,024 │ └─Dropout: 2-26 [1, 512, 2, 2] -- │ └─Conv2d: 2-27 [1, 512, 2, 2] 2,359,808 │ └─Identity: 2-28 [1, 512, 2, 2] -- ├─MHSA2d: 1-15 [1, 512, 2, 2] -- │ └─GroupNorm: 2-29 [1, 512, 2, 2] 1,024 │ └─MultiheadAttention: 2-30 [4, 1, 512] 1,050,624 │ └─Conv2d: 2-31 [1, 512, 2, 2] 262,656 ├─ResBlock: 1-16 [1, 512, 2, 2] -- │ └─GroupNorm: 2-32 [1, 512, 2, 2] 1,024 │ └─Conv2d: 2-33 [1, 512, 2, 2] 2,359,808 │ └─Linear: 2-34 [1, 512] 262,656 │ └─GroupNorm: 2-35 [1, 512, 2, 2] 1,024 │ └─Dropout: 2-36 [1, 512, 2, 2] -- │ └─Conv2d: 2-37 [1, 512, 2, 2] 2,359,808 │ └─Identity: 2-38 [1, 512, 2, 2] -- ├─ModuleList: 1-27 -- (recursive) │ └─ResBlock: 2-39 [1, 512, 2, 2] -- │ │ 
└─GroupNorm: 3-90 [1, 1024, 2, 2] 2,048 │ │ └─Conv2d: 3-91 [1, 512, 2, 2] 4,719,104 │ │ └─Linear: 3-92 [1, 512] 262,656 │ │ └─GroupNorm: 3-93 [1, 512, 2, 2] 1,024 │ │ └─Dropout: 3-94 [1, 512, 2, 2] -- │ │ └─Conv2d: 3-95 [1, 512, 2, 2] 2,359,808 │ │ └─Conv2d: 3-96 [1, 512, 2, 2] 524,800 │ └─ResBlock: 2-40 [1, 512, 2, 2] -- │ │ └─GroupNorm: 3-97 [1, 1024, 2, 2] 2,048 │ │ └─Conv2d: 3-98 [1, 512, 2, 2] 4,719,104 │ │ └─Linear: 3-99 [1, 512] 262,656 │ │ └─GroupNorm: 3-100 [1, 512, 2, 2] 1,024 │ │ └─Dropout: 3-101 [1, 512, 2, 2] -- │ │ └─Conv2d: 3-102 [1, 512, 2, 2] 2,359,808 │ │ └─Conv2d: 3-103 [1, 512, 2, 2] 524,800 ├─ModuleList: 1-26 -- (recursive) │ └─Upsample: 2-41 [1, 512, 4, 4] -- │ │ └─Conv2d: 3-104 [1, 512, 4, 4] 2,359,808 ├─ModuleList: 1-27 -- (recursive) │ └─ResBlock: 2-42 [1, 512, 4, 4] -- │ │ └─GroupNorm: 3-105 [1, 1024, 4, 4] 2,048 │ │ └─Conv2d: 3-106 [1, 512, 4, 4] 4,719,104 │ │ └─Linear: 3-107 [1, 512] 262,656 │ │ └─GroupNorm: 3-108 [1, 512, 4, 4] 1,024 │ │ └─Dropout: 3-109 [1, 512, 4, 4] -- │ │ └─Conv2d: 3-110 [1, 512, 4, 4] 2,359,808 │ │ └─Conv2d: 3-111 [1, 512, 4, 4] 524,800 │ └─ResBlock: 2-43 [1, 512, 4, 4] -- │ │ └─GroupNorm: 3-112 [1, 1024, 4, 4] 2,048 │ │ └─Conv2d: 3-113 [1, 512, 4, 4] 4,719,104 │ │ └─Linear: 3-114 [1, 512] 262,656 │ │ └─GroupNorm: 3-115 [1, 512, 4, 4] 1,024 │ │ └─Dropout: 3-116 [1, 512, 4, 4] -- │ │ └─Conv2d: 3-117 [1, 512, 4, 4] 2,359,808 │ │ └─Conv2d: 3-118 [1, 512, 4, 4] 524,800 ├─ModuleList: 1-26 -- (recursive) │ └─Upsample: 2-44 [1, 512, 8, 8] -- │ │ └─Conv2d: 3-119 [1, 512, 8, 8] 2,359,808 ├─ModuleList: 1-27 -- (recursive) │ └─ResBlock: 2-45 [1, 256, 8, 8] -- │ │ └─GroupNorm: 3-120 [1, 768, 8, 8] 1,536 │ │ └─Conv2d: 3-121 [1, 256, 8, 8] 1,769,728 │ │ └─Linear: 3-122 [1, 256] 131,328 │ │ └─GroupNorm: 3-123 [1, 256, 8, 8] 512 │ │ └─Dropout: 3-124 [1, 256, 8, 8] -- │ │ └─Conv2d: 3-125 [1, 256, 8, 8] 590,080 │ │ └─Conv2d: 3-126 [1, 256, 8, 8] 196,864 │ └─ResBlock: 2-46 [1, 256, 8, 8] -- │ │ └─GroupNorm: 3-127 [1, 512, 8, 8] 1,024 
│ │ └─Conv2d: 3-128 [1, 256, 8, 8] 1,179,904 │ │ └─Linear: 3-129 [1, 256] 131,328 │ │ └─GroupNorm: 3-130 [1, 256, 8, 8] 512 │ │ └─Dropout: 3-131 [1, 256, 8, 8] -- │ │ └─Conv2d: 3-132 [1, 256, 8, 8] 590,080 │ │ └─Conv2d: 3-133 [1, 256, 8, 8] 131,328 ├─ModuleList: 1-26 -- (recursive) │ └─Upsample: 2-47 [1, 256, 16, 16] -- │ │ └─Conv2d: 3-134 [1, 256, 16, 16] 590,080 ├─ModuleList: 1-27 -- (recursive) │ └─ResBlock: 2-48 [1, 256, 16, 16] -- │ │ └─GroupNorm: 3-135 [1, 512, 16, 16] 1,024 │ │ └─Conv2d: 3-136 [1, 256, 16, 16] 1,179,904 │ │ └─Linear: 3-137 [1, 256] 131,328 │ │ └─GroupNorm: 3-138 [1, 256, 16, 16] 512 │ │ └─Dropout: 3-139 [1, 256, 16, 16] -- │ │ └─Conv2d: 3-140 [1, 256, 16, 16] 590,080 │ │ └─Conv2d: 3-141 [1, 256, 16, 16] 131,328 │ └─ResBlock: 2-49 [1, 256, 16, 16] -- │ │ └─GroupNorm: 3-142 [1, 512, 16, 16] 1,024 │ │ └─Conv2d: 3-143 [1, 256, 16, 16] 1,179,904 │ │ └─Linear: 3-144 [1, 256] 131,328 │ │ └─GroupNorm: 3-145 [1, 256, 16, 16] 512 │ │ └─Dropout: 3-146 [1, 256, 16, 16] -- │ │ └─Conv2d: 3-147 [1, 256, 16, 16] 590,080 │ │ └─Conv2d: 3-148 [1, 256, 16, 16] 131,328 ├─ModuleList: 1-26 -- (recursive) │ └─Upsample: 2-50 [1, 256, 32, 32] -- │ │ └─Conv2d: 3-149 [1, 256, 32, 32] 590,080 ├─ModuleList: 1-27 -- (recursive) │ └─ResBlock: 2-51 [1, 128, 32, 32] -- │ │ └─GroupNorm: 3-150 [1, 384, 32, 32] 768 │ │ └─Conv2d: 3-151 [1, 128, 32, 32] 442,496 │ │ └─Linear: 3-152 [1, 128] 65,664 │ │ └─GroupNorm: 3-153 [1, 128, 32, 32] 256 │ │ └─Dropout: 3-154 [1, 128, 32, 32] -- │ │ └─Conv2d: 3-155 [1, 128, 32, 32] 147,584 │ │ └─Conv2d: 3-156 [1, 128, 32, 32] 49,280 │ └─ResBlock: 2-52 [1, 128, 32, 32] -- │ │ └─GroupNorm: 3-157 [1, 256, 32, 32] 512 │ │ └─Conv2d: 3-158 [1, 128, 32, 32] 295,040 │ │ └─Linear: 3-159 [1, 128] 65,664 │ │ └─GroupNorm: 3-160 [1, 128, 32, 32] 256 │ │ └─Dropout: 3-161 [1, 128, 32, 32] -- │ │ └─Conv2d: 3-162 [1, 128, 32, 32] 147,584 │ │ └─Conv2d: 3-163 [1, 128, 32, 32] 32,896 ├─ModuleList: 1-26 -- (recursive) │ └─Upsample: 2-53 [1, 128, 64, 64] -- │ │ 
└─Conv2d: 3-164 [1, 128, 64, 64] 147,584 ├─ModuleList: 1-27 -- (recursive) │ └─ResBlock: 2-54 [1, 128, 64, 64] -- │ │ └─GroupNorm: 3-165 [1, 256, 64, 64] 512 │ │ └─Conv2d: 3-166 [1, 128, 64, 64] 295,040 │ │ └─Linear: 3-167 [1, 128] 65,664 │ │ └─GroupNorm: 3-168 [1, 128, 64, 64] 256 │ │ └─Dropout: 3-169 [1, 128, 64, 64] -- │ │ └─Conv2d: 3-170 [1, 128, 64, 64] 147,584 │ │ └─Conv2d: 3-171 [1, 128, 64, 64] 32,896 │ └─ResBlock: 2-55 [1, 128, 64, 64] -- │ │ └─GroupNorm: 3-172 [1, 256, 64, 64] 512 │ │ └─Conv2d: 3-173 [1, 128, 64, 64] 295,040 │ │ └─Linear: 3-174 [1, 128] 65,664 │ │ └─GroupNorm: 3-175 [1, 128, 64, 64] 256 │ │ └─Dropout: 3-176 [1, 128, 64, 64] -- │ │ └─Conv2d: 3-177 [1, 128, 64, 64] 147,584 │ │ └─Conv2d: 3-178 [1, 128, 64, 64] 32,896 ├─GroupNorm: 1-28 [1, 128, 64, 64] 256 ├─Conv2d: 1-29 [1, 3, 64, 64] 387 ========================================================================================== Total params: 89,488,131 Trainable params: 89,488,131 Non-trainable params: 0 Total mult-adds (G): 12.34 ========================================================================================== Input size (MB): 0.05 Forward/backward pass size (MB): 138.44 Params size (MB): 353.75 Estimated Total Size (MB): 492.24 ==========================================================================================
In [8]:
# -----------------------------------------------------------------------------
# 4) Training loop: DDPM noise-prediction objective on the CelebA subset
# -----------------------------------------------------------------------------
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
epochs = 50
for epoch in range(epochs):
    pbar = tqdm(train_loader, desc=f"Epoch {epoch}")
    for x in pbar:
        x = x.to(device)
        bsz = x.size(0)
        # sample a timestep per image in [0, T) and standard normal noise
        t = torch.randint(0, diffusion.timesteps, (bsz,), device=device)
        noise = torch.randn_like(x)
        # diffuse to x_t with the closed-form forward kernel
        x_t = diffusion.q_sample(x, t, noise)
        # predict the injected noise; t is scaled to [0, 1) for the network
        noise_pred = model(x_t, t.float() / diffusion.timesteps)
        # regress predicted noise onto the true noise
        loss = F.mse_loss(noise_pred, noise)
        optimizer.zero_grad()
        loss.backward()
        # clamp gradients to prevent explosion
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        pbar.set_postfix(loss=loss.item())
    # display a 5x5 grid of generated samples after EVERY epoch
    # (raise the modulus to sample less often — `% 1` is always true)
    if epoch % 1 == 0:
        samples = p_sample_loop(model, diffusion, (25,3,64,64))
        # un-normalize from [-1, 1] back to [0, 1] for display
        grid = make_grid((samples + 1)*0.5, nrow=5).clamp(0,1)
        np_grid = grid.permute(1,2,0).cpu().numpy()
        plt.figure(figsize=(5,5))
        plt.imshow(np_grid)
        plt.axis('off')
        plt.show()
Epoch 0: 100%|██████████| 469/469 [01:10<00:00, 6.67it/s, loss=0.0408]
Epoch 1: 100%|██████████| 469/469 [01:10<00:00, 6.63it/s, loss=0.0827]
Epoch 2: 100%|██████████| 469/469 [01:07<00:00, 6.94it/s, loss=0.0295]
Epoch 3: 100%|██████████| 469/469 [01:05<00:00, 7.21it/s, loss=0.0306]
Epoch 4: 100%|██████████| 469/469 [01:05<00:00, 7.18it/s, loss=0.0329]
Epoch 5: 100%|██████████| 469/469 [01:10<00:00, 6.65it/s, loss=0.0337]
Epoch 6: 100%|██████████| 469/469 [01:09<00:00, 6.72it/s, loss=0.0286]
Epoch 7: 100%|██████████| 469/469 [01:10<00:00, 6.65it/s, loss=0.0423]
Epoch 8: 100%|██████████| 469/469 [01:05<00:00, 7.11it/s, loss=0.0208]
Epoch 9: 100%|██████████| 469/469 [01:04<00:00, 7.25it/s, loss=0.02]
Epoch 10: 100%|██████████| 469/469 [01:04<00:00, 7.27it/s, loss=0.0108]
Epoch 11: 100%|██████████| 469/469 [01:04<00:00, 7.24it/s, loss=0.0391]
Epoch 12: 100%|██████████| 469/469 [01:04<00:00, 7.24it/s, loss=0.0166]
Epoch 13: 100%|██████████| 469/469 [01:08<00:00, 6.87it/s, loss=0.0182]
Epoch 14: 100%|██████████| 469/469 [01:10<00:00, 6.69it/s, loss=0.0231]
Epoch 15: 100%|██████████| 469/469 [01:09<00:00, 6.71it/s, loss=0.032]
Epoch 16: 100%|██████████| 469/469 [01:10<00:00, 6.68it/s, loss=0.0333]
Epoch 17: 100%|██████████| 469/469 [01:09<00:00, 6.72it/s, loss=0.0173]
Epoch 18: 100%|██████████| 469/469 [01:10<00:00, 6.65it/s, loss=0.0289]
Epoch 19: 100%|██████████| 469/469 [01:10<00:00, 6.63it/s, loss=0.0182]
Epoch 20: 100%|██████████| 469/469 [01:09<00:00, 6.71it/s, loss=0.0343]
Epoch 21: 100%|██████████| 469/469 [01:10<00:00, 6.67it/s, loss=0.0227]
Epoch 22: 19%|█▊ | 87/469 [00:13<00:58, 6.58it/s, loss=0.0324]
--------------------------------------------------------------------------- KeyboardInterrupt Traceback (most recent call last) Cell In[8], line 27 25 loss = F.mse_loss(noise_pred, noise) 26 optimizer.zero_grad() ---> 27 loss.backward() 29 # <-- clamp gradients to prevent explosion --> 30 torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) File ~/.local/lib/python3.10/site-packages/torch/_tensor.py:581, in Tensor.backward(self, gradient, retain_graph, create_graph, inputs) 571 if has_torch_function_unary(self): 572 return handle_torch_function( 573 Tensor.backward, 574 (self,), (...) 579 inputs=inputs, 580 ) --> 581 torch.autograd.backward( 582 self, gradient, retain_graph, create_graph, inputs=inputs 583 ) File ~/.local/lib/python3.10/site-packages/torch/autograd/__init__.py:347, in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs) 342 retain_graph = create_graph 344 # The reason we repeat the same comment below is that 345 # some Python versions print out the first line of a multi-line function 346 # calls in the traceback and some print out the last line --> 347 _engine_run_backward( 348 tensors, 349 grad_tensors_, 350 retain_graph, 351 create_graph, 352 inputs, 353 allow_unreachable=True, 354 accumulate_grad=True, 355 ) File ~/.local/lib/python3.10/site-packages/torch/autograd/graph.py:825, in _engine_run_backward(t_outputs, *args, **kwargs) 823 unregister_hooks = _register_logging_hooks_on_whole_graph(t_outputs) 824 try: --> 825 return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass 826 t_outputs, *args, **kwargs 827 ) # Calls into the C++ engine to run the backward pass 828 finally: 829 if attach_logging_hooks: KeyboardInterrupt:
In [11]:
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
# 1) Draw 225 samples from pure noise via the reverse diffusion loop.
#    shape = (N, C, H, W) = (225, 3, 64, 64)
#    NOTE: relies on `p_sample_loop`, `model`, and `diffusion` defined in
#    earlier cells (training cell was interrupted above — model may be
#    only partially trained).
samples = p_sample_loop(model, diffusion, (225, 3, 64, 64))
# 2) Un-normalize from [-1, 1] -> [0, 1] for display.
samples = (samples + 1) * 0.5
samples = samples.clamp(0, 1)
# 3) Make a grid: 15 images per row -> 15x15 grid of the 225 samples.
grid = make_grid(samples, nrow=15)
# 4) Plot
plt.figure(figsize=(25,25))
# make_grid returns (C, H, W); permute to (H, W, C) for plt.imshow.
plt.imshow(grid.permute(1, 2, 0).cpu().numpy())
plt.axis("off")
plt.title("Generated Faces")
plt.show()
Stable Diffusion¶
In [1]:
# 1) Install the libraries (run once in your environment)
! pip install diffusers transformers accelerate safetensors --quiet
[notice] A new release of pip is available: 24.0 -> 25.0.1 [notice] To update, run: pip install --upgrade pip
In [1]:
import os

from huggingface_hub import login

# Prefer the HF_TOKEN environment variable — keeps the secret out of the
# notebook and avoids a machine-specific absolute path. Fall back to the
# original token file for backward compatibility.
hf_token = os.environ.get("HF_TOKEN")
if not hf_token:
    # NOTE(review): hardcoded absolute path — consider moving this file
    # path to a config cell or environment variable.
    with open("/home/kmcalist/QTM447/HFToken.txt", "r") as token_file:
        hf_token = token_file.read().strip()

# Authenticate this session with the Hugging Face Hub.
login(hf_token)
In [2]:
import torch
from diffusers import StableDiffusion3Pipeline
import matplotlib.pyplot as plt

model_id = "stabilityai/stable-diffusion-3.5-medium"

# NOTE: `safety_checker=None` was removed — StableDiffusion3Pipeline does
# not accept that kwarg (the loader warned: "Keyword arguments
# {'safety_checker': None} are not expected ... and will be ignored").
pipe = StableDiffusion3Pipeline.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,  # MMDiT performs best in bf16
)
# Move the whole pipeline (transformer, VAE, text encoders) to the GPU.
pipe = pipe.to("cuda")
2025-04-17 12:38:41.666594: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-04-17 12:38:41.666619: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-04-17 12:38:41.667606: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-04-17 12:38:41.672577: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2025-04-17 12:38:42.496755: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
Keyword arguments {'safety_checker': None} are not expected by StableDiffusion3Pipeline and will be ignored.
Loading pipeline components...: 0%| | 0/9 [00:00<?, ?it/s]
You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers
Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]
In [4]:
# 3) Generate some images
prompts = [
    "Just a regular old celebrity wearing a hat in a realistic style.",
    "A cyberpunk dog samurai standing in neon rain",
    "15 students sitting in a classroom",
    "Dog pope"
]

# Slice the attention computation to reduce peak VRAM during generation.
pipe.enable_attention_slicing()

for prompt in prompts:
    # Run one prompt at a time to keep GPU memory usage bounded.
    out = pipe(prompt, num_inference_steps=30, guidance_scale=7.5)
    img = out.images[0]

    # Display immediately using the explicit Axes interface.
    fig, ax = plt.subplots(figsize=(4, 4))
    ax.imshow(img)
    ax.axis("off")
    plt.show()
    # Close the figure so repeated iterations don't accumulate open
    # matplotlib figures (a memory leak when plotting in a loop).
    plt.close(fig)

    # Free the pipeline output and cached GPU memory before the next prompt.
    del out, img
    torch.cuda.empty_cache()
0%| | 0/30 [00:00<?, ?it/s]
0%| | 0/30 [00:00<?, ?it/s]
0%| | 0/30 [00:00<?, ?it/s]
0%| | 0/30 [00:00<?, ?it/s]